[None][feat] dual-pool KV cache with SWA block eviction for gemma4 #12813
suyoggupta wants to merge 16 commits into NVIDIA:main
Conversation
…IA#12205) Adds Gemma3n custom model with shared KV attention, sliding window attention, and related attention backend changes for AutoDeploy. Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Adds Gemma4 (MoE) custom model for AutoDeploy with: - Custom modeling code supporting K=V attention, proportional RoPE, parallel dense+MoE, per-layer scalars, and logit softcapping - Gelu activation support in torch_moe for Gemma4 MoE layers - Hierarchical equivalence tests - Model registry config (triton_paged attention backend for head_dim=512) Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
…nd tests

- Remove incorrect +1.0 scale_shift from Gemma4RMSNorm. HF transformers 5.5.0 stores effective norm weights directly in the checkpoint; the previous implementation incorrectly added 1.0 at load time, causing compounding numerical drift across layers and garbled generation.
- Add google/gemma-4-26B-A4B base model registry entry with gemma4_moe_base.yaml config.
- Strengthen test_full_model_equivalence with end-to-end logits comparison against a standalone reference model.
- Add export functional equivalence assertion (pre-export vs post-export).
- Update reference _RefRMSNorm to match corrected norm semantics.
- Update MoE block test to manually unfuse weights (hook now on decoder layer, not MoE block).

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
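The norm-weight fix in this commit boils down to whether a +1 offset is applied to checkpoint weights at load time. A minimal sketch of the two behaviors (illustrative code, not the actual Gemma4RMSNorm implementation):

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6,
             add_one: bool = False) -> torch.Tensor:
    # Normalize by the reciprocal RMS, then scale by the norm weight.
    # With add_one=True the weight is treated as a residual around 1.0;
    # if the checkpoint already stores effective weights, this double-counts.
    rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    w = weight + 1.0 if add_one else weight
    return x * rms * w

x = torch.randn(2, 4)
w = torch.ones(4)  # pretend these are effective weights from the checkpoint
out_correct = rms_norm(x, w, add_one=False)
out_drifted = rms_norm(x, w, add_one=True)  # output scaled 2x per layer here
```

With effective weights of 1.0 in the checkpoint, the spurious +1 doubles every layer's output, which compounds across layers exactly as the commit message describes.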
…ked prefill Add piecewise CUDA graph compilation, expanded batch sizes, chunked prefill, and KV cache config to both gemma4_moe.yaml and gemma4_moe_base.yaml. Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Add sliding window attention to both decode (FlashDecoding) and context/prefill kernels. When sliding_window is set, queries only attend to the most recent W KV tokens, enabling efficient long-context inference for models with sliding window attention (e.g. Mistral).

Key changes:
- Decode kernel: restrict page splits to the window range, apply a per-token window mask, use the effective sequence length for the split-K heuristic
- Context kernel: skip pages before the window in Phase 1, add a per-query sliding window mask in both Phase 1 (full pages) and Phase 2 (partial/causal pages), guard against NaN from -inf exponents
- triton_paged_mha_with_cache: thread sliding_window through to both kernels, add optional pre-allocated output buffer support
- Disable the SDPA fast-path when sliding window is active
- Extract the sliding_window constant from the source attention node

MMLU: 75, GSM8K: 90

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
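The masking rule this commit implements ("queries only attend to the most recent W KV tokens") can be sketched as a boolean mask. This is a reference sketch only; the exact inclusive/exclusive convention in the Triton kernels may differ, and the version below assumes the window includes the query's own position:

```python
import torch

def sliding_window_mask(s_q: int, s_k: int, window: int) -> torch.Tensor:
    # Absolute positions: the s_q queries are the last s_q tokens of the
    # s_k-token sequence (prefill on top of cached tokens).
    q_pos = torch.arange(s_k - s_q, s_k).unsqueeze(1)  # [s_q, 1]
    k_pos = torch.arange(s_k).unsqueeze(0)             # [1, s_k]
    causal = k_pos <= q_pos                 # no attending to future keys
    recent = k_pos > q_pos - window         # only the last `window` tokens
    return causal & recent                  # [s_q, s_k]

mask = sliding_window_mask(s_q=4, s_k=8, window=3)
```

Positions outside the mask get -inf attention scores; the NaN guard mentioned above matters when an entire row of scores would be -inf.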
…fixes, gather logits softcap, sliding window tests

- Enable MLIR elementwise fusion, gather_logits, and fuse_gemms transforms in the gemma4_moe config; switch gemma4 models to world_size_1
- Register triton_paged ops in piecewise_utils for CUDA graph capture
- Add torch.cuda.synchronize after piecewise graph replay to prevent race conditions with non-default streams (e.g. fused_moe)
- Fix the MLIR triton emitter: use tl.extra.cuda.libdevice for math ops (gelu, tanh, exp, softplus, pow); handle scalar/rank-0 tensor inputs; add the AD_DUMP_KERNELS_DIR env var for kernel source inspection
- Fix gather_logits_before_lm_head to walk backward through post-lm_head ops (div, tanh, mul softcapping) to find the actual linear node
- Add sliding window attention tests for decode and context kernels
- Add a softcapping LM head test for the gather logits transform

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
…t Geglu in NVFP4 MoE Pass window_left to fast_decode_plan in plan_generate_only so sliding window attention is respected during CUDA-graph-captured decode. Add early rejection of Gelu/Geglu in NVFP4 TRTLLM-Gen MoE since the underlying kernel does not support it. Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Migrate cuda_graph_batch_sizes to cuda_graph_config.batch_sizes and add explicit max_batch_size to gemma3n config to preserve prior default. Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
📝 Walkthrough

Introduces Gemma 3n and Gemma 4 model support with custom PyTorch implementations, configuration files, and a deployment cookbook. Adds shared KV cache and sliding-window attention infrastructure across all attention backends, multi-pool KV cache management, and updates compiler infrastructure for piecewise compilation and metadata handling.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~85 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py (1)
432-456: ⚠️ Potential issue | 🔴 Critical

Fix the `mutates_args` declaration to match the actual tensor mutations.

torch_backend_mha_with_cache() calls _write_generate_kv_cache() and _update_kv_cache() (lines 498–500), which directly modify k_cache and v_cache via indexed assignment. However, the decorator declares mutates_args=(), creating a contract mismatch. This will cause torch.compile to misoptimize the cached attention computation.

🩹 Proposed fix

-@torch.library.custom_op("auto_deploy::torch_cached_attention_with_cache", mutates_args=())
+@torch.library.custom_op(
+    "auto_deploy::torch_cached_attention_with_cache",
+    mutates_args=("k_cache", "v_cache"),
+)
 def torch_backend_mha_with_cache(

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py` around lines 432 - 456, The decorator for torch_backend_mha_with_cache incorrectly declares mutates_args=() while the function mutates k_cache and v_cache via _write_generate_kv_cache and _update_kv_cache; update the `@torch.library.custom_op` on torch_backend_mha_with_cache to list the mutated tensor arguments (k_cache and v_cache) in mutates_args so the op contract matches the actual in-place updates and prevents torch.compile misoptimizations.

tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
977-992: ⚠️ Potential issue | 🟡 Minor

Use the union of groups in the returned KV stats.

In multi-pool mode, kv_managed here is only the last group processed by the loop above, so total_managed and the returned kv_managed count under-report earlier pools. The logging becomes misleading precisely when dual-pool mode is enabled; this should use kv_managed_all.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/shim/interface.py` around lines 977 - 992, The returned KV stats and total_managed use only the last group's kv_managed; update the aggregation to use the union across all groups (kv_managed_all) instead. Replace usages of kv_managed when computing total_managed and the "kv_managed" return value with kv_managed_all (i.e., compute total_managed = len(kv_managed_all) + ssm_managed_count + conv_managed_count and return "kv_managed": len(kv_managed_all)); keep other derived counts (paged_total, kv_total, paged_other, other_total) unchanged. Use the existing symbols kv_managed_all, ssm_managed_count, conv_managed_count, total_managed, and the return dict in this function to locate where to apply the change.
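The under-counting pattern flagged above is easy to see in isolation: iterating over groups and reading only the loop variable afterwards reports just the last group, while a running union reports all pools. A toy sketch with hypothetical pool names (not the actual interface.py symbols):

```python
# Hypothetical per-group sets of managed KV cache names.
kv_groups = {
    0: {"kv_pool0_full_attn"},
    1: {"kv_pool1_sliding_attn"},
}

kv_managed = set()
kv_managed_all = set()
for _, group_caches in kv_groups.items():
    kv_managed = group_caches          # overwritten each iteration
    kv_managed_all |= group_caches     # union accumulated across pools

total_managed_buggy = len(kv_managed)       # counts only the last group
total_managed_fixed = len(kv_managed_all)   # counts every pool
```

With a single pool both counts agree, which is exactly why the bug only surfaces once dual-pool mode is enabled.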
🧹 Nitpick comments (7)
tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py (1)
135-137: Consider centralizing FlashInfer op invocation to reduce signature-drift risk.

These call-site updates are correct, but the repeated long positional argument list is brittle. A small local helper would make future arity/order changes safer.
Refactor sketch
+def _call_flashinfer_mha_with_cache(
+    q, k, v,
+    batch_info_host, qo_indptr_host, paged_kv_indptr, paged_kv_indptr_host,
+    paged_kv_indices, paged_kv_last_page_len, paged_kv_last_page_len_host,
+    seq_len_with_cache_host, batch_indices, positions, kv_cache,
+    k_scale, v_scale,
+):
+    return torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
+        q, k, v,
+        batch_info_host, qo_indptr_host, paged_kv_indptr, paged_kv_indptr_host,
+        paged_kv_indices, paged_kv_last_page_len, paged_kv_last_page_len_host,
+        seq_len_with_cache_host, batch_indices, positions, kv_cache,
+        None, None, k_scale, v_scale,
+    )

Also applies to: 265-267, 396-398, 491-493, 625-627, 784-786, 892-894, 987-989
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py` around lines 135 - 137, This test file repeats a long positional argument list when invoking the FlashInfer attention op, which is brittle; create a small local helper (e.g., call_flashinfer_attention or flashinfer_attention_helper) in tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py that wraps the actual op invocation and accepts either a partial kwargs dict or the same parameters with sensible defaults, then replace each repeated call-site (the clusters around lines shown in the comment) to call that helper instead; update occurrences referenced in the review (around the groups at 135-137, 265-267, 396-398, 491-493, 625-627, 784-786, 892-894, 987-989) so future arity/order changes only need updating in the single helper.

tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py (1)
493-496: Remove redundant q_pos assignment in the reference helper.

Line 493 is overwritten at Line 495, so it is dead code and can be dropped for clarity.
♻️ Proposed cleanup
-    q_pos = torch.arange(s_k - s_q + s_q, device=q.device)  # absolute positions
-    # For prefill: q_pos = [0..s_q-1], k_pos = [0..s_k-1]
+    # For prefill: q_pos = [0..s_q-1], k_pos = [0..s_k-1]
     q_pos = torch.arange(s_k - s_q, s_k, device=q.device)  # [s_q]

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py` around lines 493 - 496, The first assignment to q_pos (q_pos = torch.arange(s_k - s_q + s_q, device=q.device)) is dead code because it is immediately overwritten by the later assignment; remove that redundant line and retain the intended prefill assignment (q_pos = torch.arange(s_k - s_q, s_k, device=q.device)) and k_pos assignment (k_pos = torch.arange(s_k, device=q.device)); also update or keep the inline comment to reflect that q_pos now represents absolute positions for prefill.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py (2)
446-449: Use pytest.skip instead of early return for CUDA-gated tests.

Early return exits silently without recording a skip. Using pytest.skip provides visibility in test reports and is the standard pattern for conditionally skipping tests.

♻️ Suggested fix

 def test_flashinfer_shared_kv_cached_attention_reads_aliased_cache_without_writing():
     if not torch.cuda.is_available():
-        return
+        pytest.skip("CUDA not available")
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py` around lines 446 - 449, The test function test_flashinfer_shared_kv_cached_attention_reads_aliased_cache_without_writing uses an early return when CUDA is unavailable; replace that return with a call to pytest.skip("CUDA is not available") so the test is recorded as skipped, and ensure pytest is imported at the top of the test module if not already present.
96-126: Unused variable batch should be prefixed with underscore.

Static analysis (RUF059) flags batch as unused. While this is a common pattern when unpacking tensor shapes, adding an underscore prefix silences the warning and signals intent.

♻️ Suggested fix

 def _manual_attention(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     sliding_window: int | None = None,
 ) -> torch.Tensor:
-    batch, seq_len_q, num_heads, _ = q.shape
+    _batch, seq_len_q, num_heads, _ = q.shape
     _, seq_len_k, num_kv_heads, _ = k.shape

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py` around lines 96 - 126, The variable `batch` in the `_manual_attention` function is unused and triggers a static analysis warning; change the unpacking from `batch, seq_len_q, num_heads, _ = q.shape` to prefix the unused variable (e.g., `_batch, seq_len_q, num_heads, _ = q.shape`) or otherwise rename it to `_batch` to silence RUF059 and indicate it is intentionally unused; update any references if you choose a different name.

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (1)
857-862: Consider using debug-level logging for SWA eviction details.

This log is emitted once per batch (first sequence only), but during high-throughput inference with many batches, INFO-level logs can still be noisy. Consider ad_logger.debug for routine operational details, reserving INFO for significant state changes.

♻️ Suggested change

 if front_removed > 0 and i == 0:  # log once per batch, first seq only
-    ad_logger.info(
+    ad_logger.debug(
         f"SWA eviction: group={group_idx} window={window_size} "
         f"req={request.py_request_id} total_blocks={len(all_indices)} "
         f"evicted={front_removed} active={num_active} offset={page_offset_g}"
     )

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py` around lines 857 - 862, The SWA eviction message currently uses ad_logger.info and can be noisy; change the log level to debug in the SWA eviction block inside ad_executor.py (the conditional that checks front_removed > 0 and i == 0) so routine eviction details are emitted with ad_logger.debug instead of ad_logger.info, keeping the same message text and context (group_idx, window_size, request.py_request_id, len(all_indices), front_removed, num_active, page_offset_g) to preserve diagnostic data.

tests/integration/defs/accuracy/test_llm_api_autodeploy.py (1)
982-984: Mutable class attribute should use a class property or be typed as ClassVar.

Static analysis flags EXTRA_EVALUATOR_KWARGS as a mutable default value for a class attribute (RUF012). While this dict is not mutated in practice, it's a minor code smell. Consider using a ClassVar annotation or a property.

♻️ Optional fix using ClassVar annotation

+from typing import ClassVar
+
 class TestGemma4MoE(LlmapiAccuracyTestHarness):
     """Bench-run coverage for Gemma4 MoE via AutoDeploy."""

     MODEL_NAME = "google/gemma-4-26B-A4B-it"
-    EXTRA_EVALUATOR_KWARGS = {
+    EXTRA_EVALUATOR_KWARGS: ClassVar[dict[str, bool]] = {
         "apply_chat_template": True,
     }

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/integration/defs/accuracy/test_llm_api_autodeploy.py` around lines 982 - 984, EXTRA_EVALUATOR_KWARGS is defined as a mutable class attribute; change it to a ClassVar-annotated constant or convert it into a property to satisfy static analysis: annotate the symbol EXTRA_EVALUATOR_KWARGS as ClassVar[dict[str, Any]] (import ClassVar and Any from typing) or replace it with a `@property` that returns a fresh dict (e.g., def EXTRA_EVALUATOR_KWARGS(self) -> dict[str, Any]: return {"apply_chat_template": True}); update the declaration and imports accordingly.

tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
47-165: Type the new wrapper surface before more callers depend on it.

MultiPoolKVCacheManager is a new public API, but most of its methods/properties are unannotated. That makes it harder to use as a drop-in KVCacheManager replacement in type-checked code and obscures which methods intentionally diverge from single-pool behavior.

As per coding guidelines, "Always annotate functions with type hints" and "Externally called functions must have docstrings; function arguments should be documented, especially for class initializers".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/shim/interface.py` around lines 47 - 165, MultiPoolKVCacheManager is missing type hints and public docstrings; update the class and its public surface to be a proper drop-in KVCacheManager replacement by adding PEP484 type annotations and short docstrings: annotate __init__(self, managers: List[KVCacheManager], primary_idx: int = 0) -> None and document parameters, add return types for properties (e.g., impl -> Any or the actual Impl type, tokens_per_block -> int, max_blocks_per_seq -> int, blocks_in_primary_pool -> int), methods (get_num_free_blocks() -> int, get_max_resource_count() -> int, get_needed_resource_to_completion(request: RequestType) -> int or appropriate type, get_num_kv_blocks(num_tokens: int) -> int, prepare_resources(scheduled_batch: ScheduledBatchType) -> None, free_resources(request: RequestType, pin_on_release: bool = False) -> None, update_resources(...)-> None, add_dummy_requests(request_ids: Sequence[str], **kwargs) -> Any, shutdown() -> None, get_pool(group_idx: int) -> KVCacheManager, num_pools -> int, max_concurrent_sequences -> int, get_buffers(idx: int, kv_layout: str = "NHD") -> BufferType (or raise NotImplementedError with a docstring explaining alternative), event_buffer_max_size -> int, enable_block_reuse -> bool, enable_partial_reuse -> bool, is_draft -> bool, kv_cache_pool_pointers -> PointerType, kv_cache_pool_mapping -> MappingType, get_cache_indices(request: RequestType, **kwargs) -> IndexType, store_blocks_for_reuse(request: RequestType, pin_blocks: bool = False) -> None; include short docstrings on the class and each public method/prop (at least __init__, get_buffers, get_pool, and get_cache_indices) describing behavior and any divergence from single-pool KVCacheManager so callers and type checkers can rely on it.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb`:
- Line 160: The BASE_URL constant is set to a non-routable server bind address
("0.0.0.0"); change BASE_URL to use a routable client endpoint such as
"http://127.0.0.1:8000/v1" or "http://localhost:8000/v1" so client requests
target the local server correctly — locate the BASE_URL assignment in the
gemma_4_trtllm_cookbook notebook (the line containing BASE_URL =
"http://0.0.0.0:8000/v1") and replace the host portion accordingly.
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py`:
- Around line 903-905: When you grow the base cache tensor, also grow the
per-group cache tensors so their capacities stay in sync: in the branch that
calls self._input_buffer.resize("cache_loc", estimated_capacity) update every
cache_loc_g* buffer (the attributes created/used by register_window_groups,
e.g., cache_loc_g0, cache_loc_g1, etc.) to the same estimated_capacity (and do
the same mirrored update in the other resize block around lines 919-924). Locate
the places that call self._input_buffer.resize("cache_loc", ...) and add a loop
or explicit resizes to update each cache_loc_g* and any cache_loc_per_group
bookkeeping so staging (cache_loc_per_group) uses the new capacity.
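The fix this prompt describes, keeping every per-group cache_loc buffer sized in lockstep with the base buffer, can be sketched with a toy stand-in. The buffer names (cache_loc_g0, etc.) follow the review's own illustration; the real input-buffer API is not shown here and is only assumed:

```python
import torch

class ToyInputBuffer:
    """Toy stand-in for the grouped cache_loc bookkeeping in the review."""

    def __init__(self, capacity: int, num_groups: int) -> None:
        self.bufs = {"cache_loc": torch.zeros(capacity, dtype=torch.int32)}
        for g in range(num_groups):
            self.bufs[f"cache_loc_g{g}"] = torch.zeros(capacity, dtype=torch.int32)

    def resize(self, capacity: int) -> None:
        # Grow the base buffer AND every per-group buffer to the same
        # capacity, preserving existing contents, so staging never indexes
        # past a stale per-group allocation.
        for name, buf in self.bufs.items():
            if buf.numel() < capacity:
                grown = torch.zeros(capacity, dtype=buf.dtype)
                grown[: buf.numel()] = buf
                self.bufs[name] = grown

ib = ToyInputBuffer(capacity=4, num_groups=2)
ib.resize(8)
sizes = {name: buf.numel() for name, buf in ib.bufs.items()}
```

Resizing only "cache_loc" while leaving the per-group buffers at the old capacity is precisely the desync the comment warns about.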
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py`:
- Around line 257-261: The decode-only path drops the sliding-window
(window_left) info, so wire it through: add a sliding_window/ window_left
parameter to prepare_flashinfer_metadata_host() and pass it into
plan_generate_only(), and then forward that value into
flashinfer.decode.fast_decode_plan() (or alternatively detect sliding_window and
route pure-decode batches to plan_decode()); update _to_flashinfer_window_left()
usage to compute the inclusive window_left and ensure
plan_generate_only()/fast_decode_plan() receive that value so SWA decode is
planned correctly.
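FlashInfer expresses a sliding window to its planners as a window_left value (how many tokens to the left a query may attend to, with -1 meaning unlimited). A plausible helper mirroring the _to_flashinfer_window_left() conversion referenced above; the exact off-by-one convention is an assumption here, not confirmed from the source:

```python
from typing import Optional

def to_flashinfer_window_left(sliding_window: Optional[int]) -> int:
    # A window of W tokens that includes the current position allows
    # W - 1 tokens strictly to the left; -1 disables the window entirely.
    if sliding_window is None or sliding_window <= 0:
        return -1
    return sliding_window - 1

wl_none = to_flashinfer_window_left(None)
wl_512 = to_flashinfer_window_left(512)
```

The bug described above is simply that this value never reached the decode-only plan, so CUDA-graph-captured decode ran with an unlimited window.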
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py`:
- Around line 1151-1163: The current code overwrites seq_len_with_cache with a
window-local length, causing masking to mix global and local coordinates
(kv_page_offset, q_positions_2d, first_q_kv_pos vs kv_base_pos) and leading to
stale/skipped tokens after eviction; instead, preserve the absolute
seq_len_with_cache for masking and only derive a separate local length for
page-iteration bounds: compute cache_len_capped_local =
torch.minimum(cache_len_raw, max_cached) and seq_len_with_cache_local =
cache_len_capped_local + q_lens, keep seq_len_with_cache unchanged, and pass/use
seq_len_with_cache_local solely where page/local bounds are required by
_paged_context_kernel or page-iteration logic (leave masking code using
seq_len_with_cache).
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py`:
- Line 337: The constant list in TrtllmAttention.get_constants() now includes
sink_token_length as the sixth positional constant, but the cached-op signatures
incorrectly place the output buffer parameter (out) between out_scale and
sink_token_length; update the cached-op call/signature(s) so that
sink_token_length remains in the constants prefix (i.e., keep sink_token_length
as the sixth positional constant returned by TrtllmAttention.get_constants())
and move the out parameter after the constants in the cached-op invocation;
apply the same fix to the other cached-op occurrences referenced around the
blocks at the other locations (the similar cached-op signatures near the later
ranges).
In `@tensorrt_llm/_torch/auto_deploy/shim/interface.py`:
- Around line 856-912: The code assumes managers is non-empty and will raise
IndexError for empty kv_groups; before constructing MultiPoolKVCacheManager or
indexing managers[primary_idx], check if managers is empty and handle the no-KV
fallback: if no managers, set self._kv_cache_manager to the appropriate fallback
(e.g., instantiate MambaHybridCacheManager or allocate local resources) instead
of creating MultiPoolKVCacheManager, ensuring the same fallback is used where
managers[primary_idx] would be accessed; update the branch that currently
assigns self._kv_cache_manager (the block that chooses between single manager
and MultiPoolKVCacheManager) to first handle len(managers) == 0, then
len(managers) == 1, then the multi-case, and reference the symbols
_kv_cache_manager, managers, primary_idx, MultiPoolKVCacheManager, and
MambaHybridCacheManager.
In `@tensorrt_llm/_torch/auto_deploy/utils/node_utils.py`:
- Around line 1052-1058: get_op_schema currently picks an arbitrary schema from
multi-overload packets which is non-deterministic; update get_op_schema to (1)
accept explicit types (hint op: Union[torch._ops.OpOverloadPacket,
torch._ops.OpOverload]) via type hints, (2) check for a single-overload
attribute `_schema` first and return it, (3) when `_schemas` is present prefer
and return the `"default"` key if it exists, and (4) if multiple schemas exist
and no `"default"` is present raise a RuntimeError describing the ambiguous
OpOverloadPacket instead of using next(iter(...)); keep the function name
get_op_schema and callers intact.
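The deterministic lookup this prompt asks for can be sketched as follows. The _schema/_schemas attributes match the torch._ops internals as described in the review, and the SimpleNamespace objects below are stand-ins for real overload packets, used purely for illustration:

```python
from types import SimpleNamespace

def get_op_schema(op):
    """Resolve an op schema deterministically instead of picking one arbitrarily."""
    if hasattr(op, "_schema"):          # single OpOverload
        return op._schema
    schemas = op._schemas               # OpOverloadPacket: name -> schema
    if "default" in schemas:
        return schemas["default"]
    if len(schemas) == 1:
        return next(iter(schemas.values()))
    raise RuntimeError(f"Ambiguous overloads, no 'default': {sorted(schemas)}")

# Stand-ins for torch._ops objects.
single = SimpleNamespace(_schema="op() -> Tensor")
packet = SimpleNamespace(_schemas={"default": "op.default", "out": "op.out"})
```

Raising on genuinely ambiguous packets turns a silent nondeterminism into a loud, debuggable failure.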
In `@tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py`:
- Around line 1248-1261: The assertions depend on dict insertion order when
accessing interface._caches; instead, index the caches deterministically using
the resource names passed to add_resource ("kv_0" and "kv_1"). After
interface.initialize_resources(), replace the lookups that use
list(interface._caches.keys())[0/1] with direct access interface._caches["kv_0"]
and interface._caches["kv_1"] (used where kv_0 and kv_1 are assigned) so the
shape assertions reference the correct cache groups reliably.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 6358dc4d-9231-49e3-ba30-34e14606b639
📒 Files selected for processing (50)
- cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
- docs/source/models/supported-models.md
- examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb
- examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml
- examples/auto_deploy/model_registry/configs/gemma4_moe.yaml
- examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml
- examples/auto_deploy/model_registry/models.yaml
- tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py
- tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_gated_delta.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/fla/torch_backend_gated_delta.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/causal_conv_common.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/mamba_backend_common.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py
- tensorrt_llm/_torch/auto_deploy/export/export.py
- tensorrt_llm/_torch/auto_deploy/mlir/codegen/triton_emitter.py
- tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py
- tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py
- tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py
- tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
- tensorrt_llm/_torch/auto_deploy/shim/interface.py
- tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py
- tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
- tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py
- tensorrt_llm/_torch/auto_deploy/utils/_graph.py
- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
- tensorrt_llm/_torch/pyexecutor/resource_manager.py
- tests/integration/defs/accuracy/test_llm_api_autodeploy.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma3n_modeling.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py
- tests/unittest/auto_deploy/singlegpu/compile/test_captured_graph.py
- tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py
- tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py
- tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py
- tests/unittest/auto_deploy/singlegpu/transformations/library/test_gather_logits_before_lm_head.py
- tests/unittest/auto_deploy/singlegpu/transformations/library/test_kv_cache.py
| "source": [ | ||
| "from openai import OpenAI\n", | ||
| "\n", | ||
| "BASE_URL = \"http://0.0.0.0:8000/v1\"\n", |
Use a routable client endpoint instead of 0.0.0.0.
Line 160 should use 127.0.0.1 (or localhost) for client requests; 0.0.0.0 is intended for server bind, not client connect.
🔧 Proposed fix
```diff
-BASE_URL = "http://0.0.0.0:8000/v1"
+BASE_URL = "http://127.0.0.1:8000/v1"
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb` at line 160,
The BASE_URL constant is set to a non-routable server bind address ("0.0.0.0");
change BASE_URL to use a routable client endpoint such as
"http://127.0.0.1:8000/v1" or "http://localhost:8000/v1" so client requests
target the local server correctly — locate the BASE_URL assignment in the
gemma_4_trtllm_cookbook notebook (the line containing BASE_URL =
"http://0.0.0.0:8000/v1") and replace the host portion accordingly.
```python
if estimated_capacity > cache_loc_capacity:
    self._input_buffer.resize("cache_loc", estimated_capacity)
```
Keep cache_loc_g* in sync when cache_loc is resized.
register_window_groups() snapshots the current cache_loc capacity into each cache_loc_g*, but Line 904 only grows the base tensor. If max_num_tokens is smaller than max_batch_size * max_blocks_per_seq, staging cache_loc_per_group can start failing even though cache_loc itself was resized successfully.
Possible fix
```diff
 if estimated_capacity > cache_loc_capacity:
     self._input_buffer.resize("cache_loc", estimated_capacity)
+    for group_idx in range(1, self.num_window_groups):
+        self._input_buffer.resize(f"cache_loc_g{group_idx}", estimated_capacity)
```

Also applies to: 919-924
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py` around
lines 903 - 905, When you grow the base cache tensor, also grow the per-group
cache tensors so their capacities stay in sync: in the branch that calls
self._input_buffer.resize("cache_loc", estimated_capacity) update every
cache_loc_g* buffer (the attributes created/used by register_window_groups,
e.g., cache_loc_g0, cache_loc_g1, etc.) to the same estimated_capacity (and do
the same mirrored update in the other resize block around lines 919-924). Locate
the places that call self._input_buffer.resize("cache_loc", ...) and add a loop
or explicit resizes to update each cache_loc_g* and any cache_loc_per_group
bookkeeping so staging (cache_loc_per_group) uses the new capacity.
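The mirrored-resize pattern can be sketched with a toy buffer registry (the class and capacities below are made up for illustration; only the pattern of growing every `cache_loc_g*` alongside `cache_loc` matters):

```python
class ToyInputBuffer:
    """Stand-in for the real _input_buffer: tracks a capacity per named tensor."""

    def __init__(self, names, capacity):
        self.caps = dict.fromkeys(names, capacity)

    def resize(self, name, capacity):
        # Grow-only, mirroring typical buffer-resize semantics.
        self.caps[name] = max(self.caps[name], capacity)


def grow_cache_loc(buf, num_window_groups, estimated_capacity):
    """Grow the base tensor and mirror the new capacity into every window group."""
    buf.resize("cache_loc", estimated_capacity)
    for group_idx in range(1, num_window_groups):
        buf.resize(f"cache_loc_g{group_idx}", estimated_capacity)
```

With this shape, a capacity bump can never leave a per-group buffer smaller than the base tensor it snapshots.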
```python
def _to_flashinfer_window_left(sliding_window: Optional[int]) -> int:
    """Convert AD sliding-window size to FlashInfer's inclusive window_left contract."""
    if sliding_window is None or sliding_window <= 0:
        return -1
    return sliding_window - 1
```
Wire window_left through the decode-only planning path.
The sliding-window support is incomplete. While window_left is properly threaded through plan_prefill() and plan_decode() when both phases exist, the decode-only fast path still drops it:
- `prepare_flashinfer_metadata_host()` (line 313) has no `sliding_window` parameter, so it cannot pass window information to `plan_generate_only()`
- `plan_generate_only()` calls `flashinfer.decode.fast_decode_plan()` (line 145) without `window_left`, causing SWA decode under CUDA-graph warmup/replay to be planned as full attention
Either:
- Add a `sliding_window` parameter to `prepare_flashinfer_metadata_host()` and pass it through to `plan_generate_only()`, then to `fast_decode_plan()`, or
- Reroute pure-decode batches through the normal `plan_decode()` path instead of the `plan_generate_only()` fast path when sliding-window is active.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py`
around lines 257 - 261, The decode-only path drops the sliding-window
(window_left) info, so wire it through: add a sliding_window/ window_left
parameter to prepare_flashinfer_metadata_host() and pass it into
plan_generate_only(), and then forward that value into
flashinfer.decode.fast_decode_plan() (or alternatively detect sliding_window and
route pure-decode batches to plan_decode()); update _to_flashinfer_window_left()
usage to compute the inclusive window_left and ensure
plan_generate_only()/fast_decode_plan() receive that value so SWA decode is
planned correctly.
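Under the second option, the routing decision reduces to the window conversion quoted in the diff above; a sketch (the gate function is hypothetical, only the conversion mirrors the reviewed code):

```python
from typing import Optional


def to_flashinfer_window_left(sliding_window: Optional[int]) -> int:
    # Same contract as _to_flashinfer_window_left in the diff: -1 means "no window".
    if sliding_window is None or sliding_window <= 0:
        return -1
    return sliding_window - 1


def use_fast_decode_plan(sliding_window: Optional[int], is_pure_decode: bool) -> bool:
    """Hypothetical gate: keep the fast_decode_plan path only when no sliding
    window is active, so SWA batches fall back to the full plan_decode()."""
    return is_pure_decode and to_flashinfer_window_left(sliding_window) == -1
```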
```python
seq_len_with_cache = seq_len_with_cache_host[:num_prefill].to(q.device, non_blocking=True)
# For windowed cache_loc (VSWA), cap the cached-token portion of
# seq_len_with_cache to the actual pages available. Without this,
# the context kernel computes page iteration bounds from global
# seq_len, overflowing the windowed cache_loc.
# seq_len_with_cache = cache_len + q_len, where cache_len is the
# number of prior-cached tokens. Only cache_len needs capping.
q_lens = cu_seqlen[1 : num_prefill + 1] - cu_seqlen[:num_prefill]
page_counts = cu_num_pages[1 : num_prefill + 1] - cu_num_pages[:num_prefill]
max_cached = page_counts * kv_cache.shape[3]  # pages × page_size
cache_len_raw = seq_len_with_cache - q_lens
cache_len_capped = torch.minimum(cache_len_raw, max_cached)
seq_len_with_cache = cache_len_capped + q_lens
```
Don't rewrite seq_len_with_cache into window-local coordinates here.
kv_page_offset makes _paged_context_kernel interpret KV pages in absolute positions, but this cap turns seq_len_with_cache into a local length. After front eviction, q_positions_2d / first_q_kv_pos become local while kv_base_pos stays global, so the causal/SWA masks are evaluated in different coordinate systems. Prefill/extend after eviction can then admit stale tokens from the first retained page or skip valid later pages. Keep the absolute seq_len_with_cache for masking and derive local page bounds separately.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py`
around lines 1151 - 1163, The current code overwrites seq_len_with_cache with a
window-local length, causing masking to mix global and local coordinates
(kv_page_offset, q_positions_2d, first_q_kv_pos vs kv_base_pos) and leading to
stale/skipped tokens after eviction; instead, preserve the absolute
seq_len_with_cache for masking and only derive a separate local length for
page-iteration bounds: compute cache_len_capped_local =
torch.minimum(cache_len_raw, max_cached) and seq_len_with_cache_local =
cache_len_capped_local + q_lens, keep seq_len_with_cache unchanged, and pass/use
seq_len_with_cache_local solely where page/local bounds are required by
_paged_context_kernel or page-iteration logic (leave masking code using
seq_len_with_cache).
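Elementwise, the suggested separation looks like this (plain ints stand in for the tensors; the names follow the review comment and are otherwise hypothetical):

```python
def split_seq_lens(seq_len_with_cache: int, q_len: int, page_count: int, page_size: int):
    """Return (global, local): the absolute length for causal/SWA masking,
    and a separate window-local length used only for page-iteration bounds."""
    cache_len_raw = seq_len_with_cache - q_len
    max_cached = page_count * page_size            # pages x page_size
    cache_len_capped_local = min(cache_len_raw, max_cached)
    seq_len_local = cache_len_capped_local + q_len
    # Masking keeps the first value; page iteration uses the second.
    return seq_len_with_cache, seq_len_local
```

After front eviction the two diverge, which is exactly why overwriting the global value corrupts the masks.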
```python
kv_scale_quant_orig: float = 1.0,
out_scale: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
sink_token_length: int = 0,
```
Keep sink_token_length in the constant prefix.
TrtllmAttention.get_constants() now returns six positional constants, but both cached-op signatures still place out between out_scale and sink_token_length. That makes the sixth constant bind to out, so the transformed TRT-LLM path will either treat an int as the output buffer or never pass sink_token_length through.
🩹 Proposed fix

```diff
 def trtllm_mha_with_cache(
 @@
     kv_scale_orig_quant: float = 1.0,
     kv_scale_quant_orig: float = 1.0,
     out_scale: Optional[torch.Tensor] = None,
-    out: Optional[torch.Tensor] = None,
     sink_token_length: int = 0,
+    out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
 @@
 def trtllm_mha_with_cache_fake(
 @@
     kv_scale_orig_quant: float = 1.0,
     kv_scale_quant_orig: float = 1.0,
     out_scale: Optional[torch.Tensor] = None,
-    out: Optional[torch.Tensor] = None,
     sink_token_length: int = 0,
+    out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
```

Also applies to: 535-535, 635-687
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py` at
line 337, The constant list in TrtllmAttention.get_constants() now includes
sink_token_length as the sixth positional constant, but the cached-op signatures
incorrectly place the output buffer parameter (out) between out_scale and
sink_token_length; update the cached-op call/signature(s) so that
sink_token_length remains in the constants prefix (i.e., keep sink_token_length
as the sixth positional constant returned by TrtllmAttention.get_constants())
and move the out parameter after the constants in the cached-op invocation;
apply the same fix to the other cached-op occurrences referenced around the
blocks at the other locations (the similar cached-op signatures near the later
ranges).
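The hazard is easy to reproduce with toy signatures (these are not the real op signatures; two placeholder constants stand in for the six, with `sink_token_length` last):

```python
# Constants bound positionally after the tensor arguments.
constants = (None, 7)  # (out_scale, sink_token_length)


def op_broken(q, out_scale=None, out=None, sink_token_length=0):
    # `out` sits between out_scale and sink_token_length, as in the review:
    # the int constant lands on `out` and sink_token_length keeps its default.
    return {"out": out, "sink_token_length": sink_token_length}


def op_fixed(q, out_scale=None, sink_token_length=0, out=None):
    # sink_token_length kept inside the constants prefix; `out` moved after it.
    return {"out": out, "sink_token_length": sink_token_length}
```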
```python
# 2. Create one KVCacheManager per group
# SWA groups (window < max_seq_len) get fixed max_tokens.
# Full-attention groups get the remaining budget via max_tokens or free_gpu_mem_fraction.
managers: List[KVCacheManager] = []
primary_idx = 0  # index of the full-attention (largest-window) group
max_window_seen = 0

for group_idx, (kv_ref, kv_managed) in enumerate(kv_groups):
    # Compute this group's token budget
    group_max_tokens = self._compute_group_token_budget(
        group_idx, kv_ref, kv_managed, kv_groups, max_tokens
    )
    group_config = self._prepare_kv_cache_config(group_max_tokens, kv_managed)
    group_kwargs = self._build_kv_cache_kwargs(kv_ref, kv_managed, group_config)

    # NOTE: SWA groups keep max_seq_len from config (NOT window_size).
    # During prefill, sequences temporarily use up to max_seq_len blocks.
    # max_attention_window evicts old blocks during decode, freeing them
    # for new sequences. The SWA savings are throughput (more concurrent
    # decode sequences), not peak memory reduction.

    if has_state_resources and group_idx == 0:
        group_kwargs["max_batch_size"] = self.info.max_num_state_slots
        mgr, _ = self._create_and_assign_state_views(
            group_kwargs,
            ssm_ref,
            ssm_managed,
            ssm_spec,
            conv_ref,
            conv_managed,
            conv_spec,
        )
    else:
        mgr = KVCacheManager(**group_kwargs)

    managers.append(mgr)
    is_swa = self._is_swa_group(kv_managed)
    ad_logger.info(
        f"KV pool {group_idx}: {len(kv_managed)} layers, "
        f"head_dim={kv_ref.head_dim}, "
        f"max_attention_window={group_config.max_attention_window}, "
        f"swa={is_swa}, "
        f"max_tokens={group_max_tokens}"
    )

    # Track which group has the largest window (= primary for scheduler)
    group_window = max(group_config.max_attention_window or [self.info.max_seq_len])
    if group_window > max_window_seen:
        max_window_seen = group_window
        primary_idx = group_idx

# 3. Store manager (wrapper if multi-group, direct if single)
if len(managers) == 1:
    self._kv_cache_manager = managers[0]
else:
    self._kv_cache_manager = MultiPoolKVCacheManager(managers, primary_idx=primary_idx)
```
Handle the no-KV fallback before constructing the wrapper.
When kv_groups is empty, the loop on Lines 863-891 never appends to managers, but Line 911 still constructs a MultiPoolKVCacheManager and Lines 925-928 immediately index managers[primary_idx]. That regresses pure-state or fully unmanaged graphs into an IndexError instead of creating a MambaHybridCacheManager or just allocating local resources.
Also applies to: 925-928
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/auto_deploy/shim/interface.py` around lines 856 - 912,
The code assumes managers is non-empty and will raise IndexError for empty
kv_groups; before constructing MultiPoolKVCacheManager or indexing
managers[primary_idx], check if managers is empty and handle the no-KV fallback:
if no managers, set self._kv_cache_manager to the appropriate fallback (e.g.,
instantiate MambaHybridCacheManager or allocate local resources) instead of
creating MultiPoolKVCacheManager, ensuring the same fallback is used where
managers[primary_idx] would be accessed; update the branch that currently
assigns self._kv_cache_manager (the block that chooses between single manager
and MultiPoolKVCacheManager) to first handle len(managers) == 0, then
len(managers) == 1, then the multi-case, and reference the symbols
_kv_cache_manager, managers, primary_idx, MultiPoolKVCacheManager, and
MambaHybridCacheManager.
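The suggested branch order, with the managers stubbed out (a sketch only; the real fallback construction depends on which state resources exist):

```python
def pick_kv_cache_manager(managers, primary_idx, make_fallback, make_multi_pool):
    """0 managers -> fallback (pure-state / unmanaged graph); 1 -> direct;
    many -> multi-pool wrapper anchored on the largest-window group."""
    if not managers:
        # e.g. a MambaHybridCacheManager, or just local resource allocation
        return make_fallback()
    if len(managers) == 1:
        return managers[0]
    return make_multi_pool(managers, primary_idx)
```

Checking the empty case first is what prevents the `managers[primary_idx]` IndexError the review describes.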
```python
def get_op_schema(op) -> torch.FunctionSchema:
    """Return the schema for an op or op overload packet."""
    if hasattr(op, "_schemas"):
        return next(iter(op._schemas.values()))
    if hasattr(op, "_schema"):
        return op._schema
    raise RuntimeError(f"No schema found on op {op}")
```
Avoid selecting an arbitrary schema from multi-overload packets in get_op_schema().
When node.target is an OpOverloadPacket (possible since is_op() explicitly matches both OpOverload and OpOverloadPacket), next(iter(op._schemas.values())) at line 1055 selects a non-deterministic schema, which can cause argument lookup failures in extract_op_args() and set_op_args().
The most direct risk is in export.py:279, where after is_op(node, moe_ops) (with OpOverloadPacket objects), node.target is passed directly without normalizing to an explicit overload.
Proposed fix:
- Reorder checks to prefer `_schema` (single overload) before `_schemas`
- When handling multi-overload packets, prefer the `"default"` key if present
- Raise an error for ambiguous packets rather than selecting arbitrarily
- Add type hints to clarify accepted input types
Suggested implementation
-def get_op_schema(op) -> torch.FunctionSchema:
- """Return the schema for an op or op overload packet."""
- if hasattr(op, "_schemas"):
- return next(iter(op._schemas.values()))
- if hasattr(op, "_schema"):
- return op._schema
+def get_op_schema(op: OpOverloadPacket | OpOverload) -> torch.FunctionSchema:
+ """Return the schema for an op overload or overload packet."""
+ if hasattr(op, "_schema"):
+ return op._schema
+ if hasattr(op, "_schemas"):
+ schemas = op._schemas
+ if "default" in schemas:
+ return schemas["default"]
+ if len(schemas) == 1:
+ return next(iter(schemas.values()))
+ raise RuntimeError(
+ f"Ambiguous overload packet {op}: multiple overloads {list(schemas.keys())}. "
+ "Pass an explicit overload (e.g., op.default)."
+ )
raise RuntimeError(f"No schema found on op {op}")
-def _get_op_schema(node: Node):
+def _get_op_schema(node: Node) -> torch.FunctionSchema:

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/auto_deploy/utils/node_utils.py` around lines 1052 -
1058, get_op_schema currently picks an arbitrary schema from multi-overload
packets which is non-deterministic; update get_op_schema to (1) accept explicit
types (hint op: Union[torch._ops.OpOverloadPacket, torch._ops.OpOverload]) via
type hints, (2) check for a single-overload attribute `_schema` first and return
it, (3) when `_schemas` is present prefer and return the `"default"` key if it
exists, and (4) if multiple schemas exist and no `"default"` is present raise a
RuntimeError describing the ambiguous OpOverloadPacket instead of using
next(iter(...)); keep the function name get_op_schema and callers intact.
interface.add_resource("kv_0", KVPagedResourceHandler(8, 64, dtype=torch.float16))
interface.add_resource("kv_1", KVPagedResourceHandler(4, 128, dtype=torch.float16))

interface.initialize_resources()

# Group 0 (head_dim=64): cache shape [..., 8, 32, 64]
kv_0 = interface._caches[list(interface._caches.keys())[0]]
assert kv_0.shape[-1] == 64
assert kv_0.shape[-3] == 8  # num_kv_heads

# Group 1 (head_dim=128): cache shape [..., 4, 32, 128]
kv_1 = interface._caches[list(interface._caches.keys())[1]]
assert kv_1.shape[-1] == 128
assert kv_1.shape[-3] == 4  # num_kv_heads
Avoid dict-order dependence in cache-shape assertions.
Line 1254 and Line 1259 rely on _caches key order. This can make the test flaky if insertion or registration order changes. Use the returned resource names from add_resource(...) for deterministic lookup.
✅ Suggested deterministic lookup
- interface.add_resource("kv_0", KVPagedResourceHandler(8, 64, dtype=torch.float16))
- interface.add_resource("kv_1", KVPagedResourceHandler(4, 128, dtype=torch.float16))
+ kv_0_name = interface.add_resource("kv_0", KVPagedResourceHandler(8, 64, dtype=torch.float16))
+ kv_1_name = interface.add_resource("kv_1", KVPagedResourceHandler(4, 128, dtype=torch.float16))
@@
- kv_0 = interface._caches[list(interface._caches.keys())[0]]
+ kv_0 = interface._caches[kv_0_name]
@@
- kv_1 = interface._caches[list(interface._caches.keys())[1]]
+ kv_1 = interface._caches[kv_1_name]🧰 Tools
🪛 Ruff (0.15.9)
[warning] 1254-1254: Prefer next(iter(interface._caches.keys())) over single element slice
Replace with next(iter(interface._caches.keys()))
(RUF015)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/unittest/auto_deploy/singlegpu/shim/test_cached_sequence_interface.py`
around lines 1248 - 1261, The assertions depend on dict insertion order when
accessing interface._caches; instead, index the caches deterministically using
the resource names passed to add_resource ("kv_0" and "kv_1"). After
interface.initialize_resources(), replace the lookups that use
list(interface._caches.keys())[0/1] with direct access interface._caches["kv_0"]
and interface._caches["kv_1"] (used where kv_0 and kv_1 are assigned) so the
shape assertions reference the correct cache groups reliably.
Force-pushed from 9917059 to e4da73b
…sliding window attention Adds dual-pool KV cache architecture for models with mixed attention types (e.g., gemma4-26B with head_dim=256 sliding + head_dim=512 full attention). Each head_dim group gets its own KVCacheManager pool with independent max_attention_window, enabling SWA block eviction during decode. Architecture: WindowPlan is the single source of truth for VSWA. It separates logical attention-window routing (which layers share page tables) from physical KV storage pooling (which layers share block pools). Both graph wiring and runtime metadata emission derive from it, eliminating predicate drift between the transform and executor. Key changes: - WindowPlan dataclass: per_layer_window, unique_windows, group indices, group_to_pool_idx mapping (decouples window groups from storage pools) - MultiPoolKVCacheManager: delegates lifecycle to all storage pools - _identify_managed_kv_groups: groups layers by (head_dim, dtype, kv_factor) - Per-group cache_loc/cu_num_pages/kv_page_offset via VSWA graph wiring - kv_page_offset in write kernel for window-relative page indexing - kv_page_offset in context kernel for correct position-based masking - cache_len capping from cu_num_pages in triton_paged_mha_with_cache - get_num_front_blocks_removed C++ binding for SWA eviction tracking - N-based proportional memory budget split across pools - max_concurrent_sequences scheduler cap for multi-pool safety - Unit tests for multi-group identification, dual-pool creation, and per-group max_attention_window scoping Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
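The `WindowPlan` described above can be sketched as a small dataclass. This is an illustrative reconstruction from the commit message only (field names follow it, but the real implementation in the PR may differ): the plan derives window groups from the per-layer windows, while `group_to_pool_idx` keeps storage pooling a separate mapping.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class WindowPlan:
    """Single source of truth for VSWA: window routing vs. storage pooling."""

    # Attention window per layer (e.g., 1024 for sliding, a huge value for full).
    per_layer_window: List[int]
    # Maps a window group to a physical storage pool (decoupled on purpose).
    group_to_pool_idx: Dict[int, int] = field(default_factory=dict)
    # Derived fields, computed once in __post_init__.
    unique_windows: List[int] = field(init=False)
    layer_to_group: List[int] = field(init=False)

    def __post_init__(self):
        # One group per distinct window size; layers sharing a window
        # share page tables (routing), regardless of which pool stores them.
        self.unique_windows = sorted(set(self.per_layer_window))
        self.layer_to_group = [
            self.unique_windows.index(w) for w in self.per_layer_window
        ]


# Gemma-style interleaving: sliding (1024) and full (1 << 30) attention layers.
plan = WindowPlan(per_layer_window=[1024, 1024, 1 << 30, 1024, 1024, 1 << 30])
print(plan.unique_windows)   # [1024, 1073741824]
print(plan.layer_to_group)   # [0, 0, 1, 0, 0, 1]
```

Because both graph wiring and runtime metadata emission read from the same derived fields, the transform and the executor cannot drift apart in how they classify layers.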
lucaslie left a comment
Okay, here's my overall thought. I think this is a very nice design, and it makes a lot of sense. We will start needing to support different groups, including different groups for metadata.
That being said, maybe this is just personal taste, but I do think that we could change the design slightly in order to clean up some of the dependencies.
- To begin with, I would love for all of the cache-related logic to be encapsulated in the resource handler classes. This way, there's a very clean interface between the individual layer-wise attention operators and our cache management interface in the shim.
- Next up is what I think is the most important part of the change in the shim. Instead of just naively collecting all the resource handlers, we may have to analyze them on the fly: in particular, decide when a handler can go into an existing KV cache manager and when we have to initialize a new group with a new KV cache manager. This design, regardless of sliding windows or other attention features, should prove very scalable in the future as well.
- As part of that, we can also initialize a new group of metadata fields from the cached sequence interface in the attention interface. Ideally, we carry the concept of a group over to the attention interface and use a standardized way to initialize a new group of metadata for each new KV cache manager.
- With all this information in place, we can go back to the KV cache transform. When the transform requests certain metadata arguments, we can return the nodes/metadata inputs that correspond to the particular group, so the correct group is inserted dynamically.
- The final step that ties it all together is prepare inputs, where for each KV cache manager (each corresponding to a different group) we prep the metadata for all groups and pass it into the attention interface.
What do you think of this, given that the change we're introducing here is very heavy? It might be worth digging a little deeper.
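The on-the-fly grouping suggested above can be sketched in a few lines. This is an illustrative draft only, not the PR's implementation: handlers are keyed by the same `(head_dim, dtype, kv_factor)` tuple that the commit message says `_identify_managed_kv_groups` uses, and each distinct key would get its own KV cache manager. The handler dicts and field names here are assumptions.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def group_handlers(handlers: Dict[str, dict]) -> Dict[Tuple, List[str]]:
    """Group resource handlers that can share one KV cache manager.

    Handlers with the same (head_dim, dtype, kv_factor) are compatible;
    each new key implies initializing a new group / manager.
    """
    groups: Dict[Tuple, List[str]] = defaultdict(list)
    for name, h in handlers.items():
        key = (h["head_dim"], h["dtype"], h.get("kv_factor", 2))
        groups[key].append(name)
    return dict(groups)


# Mixed-attention model: two head_dim variants land in two groups.
handlers = {
    "attn_0": {"head_dim": 256, "dtype": "bf16"},
    "attn_1": {"head_dim": 512, "dtype": "bf16"},
    "attn_2": {"head_dim": 256, "dtype": "bf16"},
}
print(group_handlers(handlers))
```

The shim would then iterate over the resulting groups: reuse a manager when the key already exists, spin up a new one (plus a new group of metadata fields) otherwise.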
@classmethod
-def get_constants(cls, source_attn_node: Node) -> List[Constant]:
+def get_constants(
+    cls, source_attn_node: Node, cache_config: Optional["KvCacheConfig"] = None
Why can the sink token length not be extracted from the node just like the other constants?
This is a full device sync.
Syncing a specific stream would be a finer-grained approach and avoids stalling the GPU pipeline. I think we should do a stream sync here.
Could you explain the rationale for changing the comment here? Thank you.
If kv_managed is empty, this returns true. We should guard this with an assertion.
If kv_idx does not overlap, the function also returns true. We should guard this as well.
We should account for max_batch_size here.
Summary
- Each head_dim group gets its own `KVCacheManager` pool with independent `max_attention_window`, enabling SWA block eviction during decode
- `MultiPoolKVCacheManager` wrapper provides a unified API for lifecycle, scheduling, and block retrieval
- `kv_page_offset` support in the triton write and context kernels for correct windowed cache_loc indexing
- `get_num_front_blocks_removed` binding for SWA eviction tracking
Test plan
- `build_and_run_ad.py` with the gemma4-26B-A4B-it chat template — coherent output
- `TestGemma4MoE::test_bf16` — 75.6%, matching baseline 75.4%
- SWA eviction logs: `SWA eviction: group=0 window=1024 ... evicted=N`
Stacked on #12710
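The commit message mentions an "N-based proportional memory budget split across pools". A plausible sketch of that idea, under the assumption that each pool's share of the free KV-cache memory scales with its bytes-per-token (num_kv_heads × head_dim × dtype size × kv_factor) — the actual weighting in the PR may differ:

```python
from typing import Dict, List


def split_budget(total_bytes: int, pools: List[Dict[str, int]]) -> List[int]:
    """Split a memory budget across KV pools proportionally to bytes-per-token.

    Assumes kv_factor == 2 (separate K and V); pool dicts are illustrative.
    """
    weights = [
        p["num_kv_heads"] * p["head_dim"] * p["dtype_size"] * 2 for p in pools
    ]
    total_w = sum(weights)
    # Integer floor division: the sum of shares never exceeds the budget.
    return [total_bytes * w // total_w for w in weights]


pools = [
    {"num_kv_heads": 8, "head_dim": 256, "dtype_size": 2},  # sliding layers
    {"num_kv_heads": 4, "head_dim": 512, "dtype_size": 2},  # full-attention layers
]
print(split_budget(1 << 30, pools))  # equal bytes-per-token -> equal halves
```

With these example shapes the two pools happen to have identical bytes-per-token (8×256 == 4×512), so a 1 GiB budget splits evenly; pools with different footprints would receive proportionally different shares.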
🤖 Generated with Claude Code